#!/usr/bin/env python
"""
The MIT License

Copyright (c) 2008 Vijay Ganesh
              2014 Jurriaan Bremer, jurriaanbremer@gmail.com

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

import sys
# Make sure we can find the STP python package
sys.path.insert(0, "@PYTHON_INTERFACE_DIR@/stp")

import stp
import unittest
import ctypes


# Expression helper used by TestSTP.test_ast, which calls it with solver
# bitvectors in both argument orders and requires the two results to be
# equal.
# NOTE(review): the stp.stp decorator is defined outside this file;
# presumably it is what lets the `assert` and the arithmetic below operate
# on symbolic values -- confirm against the stp package. Documentation is
# kept above the decorator (not as a docstring) in case the decorator
# inspects the function body.
@stp.stp
def test0_foo(a, b):
    assert b != 0
    return (a + 42) % b


# Expression helper used by TestSTP.test_ast as its second constraint:
# the product of the two arguments equals 12345.
# NOTE(review): the stp.stp decorator is defined outside this file;
# presumably it makes the comparison below yield a solver expression when
# called with bitvectors -- confirm against the stp package. Documentation
# is kept above the decorator (not as a docstring) in case the decorator
# inspects the function body.
@stp.stp
def test0_bar(a, b):
    return a * b == 12345


class TestSTP(unittest.TestCase):
    """Test cases for the Python bindings exposed by the ``stp`` package.

    Every test receives a fresh ``stp.Solver`` in ``self.s`` (see ``setUp``).
    """

    def setUp(self):
        """Create a fresh solver and record the known SAT backend names."""
        self.s = stp.Solver()
        # Suffixes of the supports*/use*/isUsing* solver methods that
        # _test_solver looks up by name.
        self.stp_solvers = {
            "Minisat", "SimplifyingMinisat", "Cryptominisat", "Riss",
        }

    def test_long(self):
        """Solve a + b == 1337 with b == 42 on 256-bit vectors.

        The constants are built from decimal strings via ``bitvecvalD``.
        """
        s = self.s
        a = s.bitvec('a', 256)
        b = s.bitvec('b', 256)
        c = s.bitvecvalD(256, "1337")
        d = s.bitvecvalD(256, "42")
        self.assertTrue(s.check(b.eq(d), a.add(b).eq(c)))
        self.assertEqual(s.model(), {'a': 1337 - 42, 'b': 42})

    def test_simple(self):
        """Solve a + b == 1337 with b == 42 on 32-bit vectors."""
        s = self.s
        a = s.bitvec('a', 32)
        b = s.bitvec('b', 32)
        c = s.bitvecval(32, 1337)
        d = s.bitvecval(32, 42)
        self.assertTrue(s.check(b.eq(d), a.add(b).eq(c)))
        self.assertEqual(s.model(), {'a': 1337 - 42, 'b': 42})

    def test_equal(self):
        """b == a contradicts a > b but is compatible with a == a."""
        s = self.s
        a = s.bitvec('a', 32)
        b = s.bitvec('b', 32)
        self.assertFalse(s.check(b.eq(a), a.gt(b)))
        self.assertTrue(s.check(b.eq(a), a.eq(a)))

    def test_equal2(self):
        """b == a contradicts a > b but is compatible with a == b."""
        s = self.s
        a = s.bitvec('a', 32)
        b = s.bitvec('b', 32)
        self.assertFalse(s.check(b.eq(a), a.gt(b)))
        self.assertTrue(s.check(b.eq(a), a.eq(b)))

    def test_bitvecval(self):
        """A plain Python integer can be passed to ``eq``."""
        s = self.s
        a = s.bitvec('a', 32)
        b = s.bitvec('b', 32)
        self.assertTrue(s.check(a.add(b).eq(69)))
        m = s.model()
        # The constraint is over 32-bit vectors, so the Python-integer sum
        # of the model values only has to match modulo 2**32.
        self.assertEqual((m['a'] + m['b']) % 2**32, 69)

    def test_operator(self):
        """Constraints written with +, -, << and == yield a valid model."""
        s = self.s
        a = s.bitvec('a', 32)
        b = s.bitvec('b', 32)
        c = s.bitvec('c', 32)
        self.assertTrue(s.check(a + b + c == 123, b - c == 66))
        m = s.model()
        self.assertEqual((m['a'] + m['b'] + m['c']) % 2**32, 123)
        # b - c may wrap around in 32 bits (b < c), so reduce modulo 2**32.
        self.assertEqual((m['b'] - m['c']) % 2**32, 66)

        d = s.bitvec('d', 32)
        self.assertTrue(s.check((a << 1) - d == b))
        m = s.model()
        # Compare against the model's concrete value for b, not against the
        # symbolic variable, and account for 32-bit wrap-around.
        self.assertEqual(((m['a'] << 1) - m['d']) % 2**32, m['b'])

    def test_quick_model(self):
        """Model values can be read by indexing the solver with a name."""
        s = self.s
        a = s.bitvec('a', 32)
        b = s.bitvec('b', 32)
        c = s.bitvec('c', 32)
        self.assertTrue(s.check(a + b + c == 666, b - c == 321, c != 666))
        self.assertEqual((s['a'] + s['b'] + s['c']) % 2**32, 666)
        self.assertEqual((s['b'] - s['c']) % 2**32, 321)

    def test_bitvec32(self):
        """``bitvecs`` creates several vectors from a space-separated string.

        The arithmetic check below treats them as 32 bits wide.
        """
        s = self.s
        a, b, c = s.bitvecs('a b c')
        s.add(a != 0, b != 0, c != 0, a != b, b != c, a != c)
        self.assertTrue(s.check(a * 2 + b * 2 == c * 2))
        self.assertEqual(
            (s['a'] * 2 + s['b'] * 2) % 2**32, s['c'] * 2 % 2**32
        )

    def test_boolean_expr(self):
        """``or_`` builds a disjunction; at least one branch holds in the model."""
        s = self.s
        a, b, c = s.bitvecs('a b c')
        s.add(a != b, a != c, b != c)
        self.assertTrue(s.check(s.or_(a + b == 1, a + c == 1)))
        first_holds = (s['a'] + s['b']) % (2**32) == 1
        second_holds = (s['a'] + s['c']) % (2**32) == 1
        self.assertTrue(first_holds or second_holds)

    def test_ast(self):
        """Enumerate every model of the ``test0_foo``/``test0_bar`` constraints.

        Each model found is excluded with a permanent ``x != value``
        constraint; exactly the four expected models must be produced.
        """
        models = [
            {'y': 1658330539, 'x': 1658330539},
            {'y': 3805814187, 'x': 3805814187},
            {'y': 2636636757, 'x': 2636636757},
            {'y': 489153109, 'x': 489153109},
        ]

        with stp.Solver() as s:
            x, y = s.bitvecs('x y')
            count = 0
            while s.check(test0_foo(x, y) == test0_foo(y, x),
                          test0_bar(x, y)):
                count += 1
                # unittest assertions (unlike bare `assert`) survive
                # `python -O` and report the offending model on failure.
                self.assertIn(s.model(), models)
                s.add(x != s.model('x'))

            self.assertEqual(count, 4)

    def test_value(self):
        """After a satisfiable check, ``.value`` holds the variable's model value."""
        s = self.s
        a = s.bitvec('a')
        self.assertTrue(s.check(a == 0x41414141))
        self.assertEqual(a.value, 0x41414141)

    def test_empty_check(self):
        """``check()`` without arguments uses the constraints given to ``add``."""
        s = self.s
        a = s.bitvec('a')
        s.add(a == 2)
        self.assertTrue(s.check())

    def test_empty_check_wrong(self):
        """``check()`` without arguments fails on contradictory ``add`` constraints."""
        s = self.s
        a = s.bitvec('a')
        s.add(a == 2)
        s.add(a == 3)
        self.assertFalse(s.check())

    def test_double_assume(self):
        """Expressions passed to ``check`` do not persist to the next ``check``."""
        s = self.s
        b = s.bitvec('b')
        self.assertTrue(s.check(b == 1))
        self.assertTrue(s.check(b == 2))

    def test_true_then_false(self):
        """a == b and a > b are each satisfiable when checked separately."""
        s = self.s
        a = s.bitvec('a', 32)
        b = s.bitvec('b', 32)
        self.assertTrue(s.check(a == b))
        self.assertTrue(s.check(a > b))

    def test_empty_check_interleaved_with_add(self):
        """Smoke test: interleave ``add`` and ``check`` calls.

        This method used to be a second ``test_empty_check`` and silently
        replaced the first definition, so only one of the two ever ran.
        It makes no assertions; it only verifies that the call sequence
        completes without raising.
        """
        s = self.s
        a = s.bitvec('a', 32)
        b = s.bitvec('b', 32)
        s.add(a == b)
        s.check()
        s.check(a > b)
        s.add(a > b)
        s.add(a == b)

    def test_non_bitvector_constant_assert_fail(self):
        """Smoke test: mix comparisons, ``add`` expressions and checks.

        Makes no assertions; it only verifies that the call sequence
        completes without raising.
        """
        s = self.s
        a = s.bitvec('a', 32)
        b = s.bitvec('b', 32)
        s.check(a == b)
        s.check(a >= b)
        s.check(a > b)
        s.check(a < b)
        a = a.add(b)
        s.check()
        s.check(a > b)
        b.add(a)

    def test_pushpop(self):
        """``push``/``pop`` scope ``add`` constraints; ``check`` arguments are transient."""
        s = self.s

        s.push()
        # .add is permanent - once an expression has been added, it cannot
        # be contradicted.
        a = s.bitvec('a')
        s.add(a == 2)
        self.assertTrue(s.check())
        self.assertEqual(s.model(), {'a': 2})

        # Contradicts the earlier .add formula.
        self.assertFalse(s.check(a == 3))

        # Contradicts the earlier .add formula.
        s.add(a == 3)
        self.assertFalse(s.check())
        s.pop()

        s.push()
        # .check is only used in the current query - an expression given
        # to .check is not persistent, and can contradict any earlier
        # expression that has been given to .check.
        b = s.bitvec('b')
        self.assertTrue(s.check(b == 1))
        self.assertTrue(s.check(b == 2))
        s.pop()

    def test_operator_overloading(self):
        """Python operators on bitvectors agree with 32-bit integer arithmetic.

        Covers symbolic/symbolic, symbolic/constant and constant/symbolic
        operand combinations for binary operators, plus unary operators.
        """
        s = self.s
        a, b, c = s.bitvecs('a b c')
        A, B = 42, 52
        s.add(a == 42, b == 52)

        # The lambdas below are built with eval() from the fixed operator
        # lists in this method only; no external input is evaluated.

        # bv OP bv => bv
        for opname in "+ - * // % << >> & | ^".split():
            op = eval("lambda a, b: a %s b" % opname)

            # Test symbolic OP symbolic. The check must succeed, otherwise
            # the model read afterwards would not belong to this query.
            self.assertTrue(s.check(c == op(a, b)))
            self.assertEqual(s.model()['c'], op(A, B) % 2**32)
            self.assertTrue(s.check(c == op(b, a)))
            self.assertEqual(s.model()['c'], op(B, A) % 2**32)

            # Test symbolic OP constant.
            self.assertTrue(s.check(c == op(a, B)))
            self.assertEqual(s.model()['c'], op(A, B) % 2**32)

            # Test constant OP symbolic.
            self.assertTrue(s.check(c == op(B, a)))
            self.assertEqual(s.model()['c'], op(B, A) % 2**32)

        # bv OP bv => bool
        for opname in "< > <= >= == !=".split():
            op = eval("lambda a, b: a %s b" % opname)

            # Test symbolic OP symbolic.
            self.assertEqual(s.check(op(a, b)), op(A, B))
            self.assertEqual(s.check(op(b, a)), op(B, A))

            # Test symbolic OP constant.
            self.assertEqual(s.check(op(a, B)), op(A, B))

            # Test constant OP symbolic.
            self.assertEqual(s.check(op(B, a)), op(B, A))

        # OP bv => bv
        for opname in "- + ~".split():
            op = eval("lambda a: %s a" % opname)

            self.assertTrue(s.check(c == op(a)))
            self.assertEqual(s.model()['c'], op(A) % 2**32)

    def test_get_value_for_bitvecval(self):
        """``name``/``value`` work for variables and constants alike.

        Also looks up the value of a compound expression through
        ``model(expr=...)``.
        """
        solver = self.s
        name_for_x = "x"
        value_for_x = 0
        x = solver.bitvec(name_for_x)
        self.assertEqual(x.name, name_for_x)
        self.assertEqual(x.value, value_for_x)

        name_for_bitvecval = None
        value_for_bitvecval = 18
        bitvecval = solver.bitvecval(32, value_for_bitvecval)
        self.assertEqual(bitvecval.name, name_for_bitvecval)
        self.assertEqual(bitvecval.value, value_for_bitvecval)

        #
        # Build up a complex expression and check we can look it up from the
        # model directly
        #
        expr_for_add = x.add(bitvecval.mul(2))
        value_from_model = solver.model(expr=expr_for_add)
        self.assertEqual(
            value_from_model, value_for_x + (value_for_bitvecval * 2)
        )

    def test_urem(self):
        """``rem`` on bitvectors matches Python's ``%`` on the model values."""
        #
        # Validate the `urem` capability in STP -- note: `urem` == `mod`!
        #
        solver = self.s
        width = 32
        needle = 0xDEADBEEF
        x = solver.bitvec("x", width)
        y = solver.bitvec("y", width)
        expr = x.rem(y)
        solver.add(expr == needle)
        solver.add(y != 0)  # avoid divide by zero
        self.assertTrue(solver.check(), "instance should be satisfiable")
        model = solver.model()
        val_for_x = model["x"]
        val_for_y = model["y"]
        concrete_value = val_for_x % val_for_y
        self.assertEqual(
            concrete_value, needle,
            "Concrete value did not match expected value",
        )

    def _test_solver(self, solver, is_default_solver=False):
        """
        Helper method to validate that the passed in "solver" can be set and
        used

        ``solver`` is the name suffix used to look up the ``supports*``,
        ``use*`` and ``isUsing*`` methods on ``self.s``. When
        ``is_default_solver`` is true the solver must report itself as
        supported.
        """
        supports_method = getattr(self.s, "supports{:s}".format(solver))
        use_method = getattr(self.s, "use{:s}".format(solver))
        is_using_method = getattr(self.s, "isUsing{:s}".format(solver))

        supported = supports_method()

        if is_default_solver:
            self.assertTrue(
                supported,
                "{solver:s} is a default solver, which is reporting as disabled".format(
                    solver=solver
                ),
            )

        #
        # If the solver is supported, the "use" and "isUsing" methods should
        # return true -- similarly, if the solver is not supported, then they
        # both return false.
        #
        # We can use that as the condition for the assertion
        #
        self.assertEqual(
            use_method(),
            supported,
            "Enabling {solver:s} did not match its 'supported' status".format(
                solver=solver
            ),
        )
        self.assertEqual(
            is_using_method(),
            supported,
            "Checking if {solver:s} is in use did not match its 'supported' status".format(
                solver=solver
            ),
        )

        #
        # When we're using a supported solver, then after calling `use_method`,
        # all other solvers should not be in use
        #
        if supported:
            for other_solver in self.stp_solvers - {solver}:
                other_is_using_method = getattr(
                    self.s, "isUsing{:s}".format(other_solver)
                )
                self.assertFalse(
                    other_is_using_method(),
                    "{other_solver:s} reported to be in use, while {solver:s} was previously selected".format(
                        solver=solver, other_solver=other_solver
                    ),
                )

    def test_minisat(self):
        """
        Checks that the underlying solver can be set to be minisat
        """
        self._test_solver("Minisat", is_default_solver=True)

    def test_simplifyingminisat(self):
        """
        Checks that the underlying solver can be set to be simplifyingminisat
        """
        self._test_solver("SimplifyingMinisat", is_default_solver=True)

    def test_cryptominisat(self):
        """
        Checks that the underlying solver can be set to be cryptominisat
        """
        self._test_solver("Cryptominisat", is_default_solver=False)

    def test_riss(self):
        """
        Checks that the underlying solver can be set to be riss
        """
        self._test_solver("Riss", is_default_solver=False)

    def test_git_version_sha(self):
        """
        Test verifies that the returned git sha is of non-zero length
        """
        self.assertNotEqual(len(stp.get_git_version_sha()), 0)

    def test_git_version_tag(self):
        """
        Test verifies that the returned git tag is of non-zero length
        """
        self.assertNotEqual(len(stp.get_git_version_tag()), 0)

    def test_compilation_env(self):
        """
        Test verifies that the returned compilation env is of non-zero length
        """
        self.assertNotEqual(len(stp.get_compilation_env()), 0)

    def test_lib(self):
        """
        Test verifies that the returned CDLL object is "well-formed"
        """
        stp_cdll = stp.get_lib()
        self.assertIsNotNone(stp_cdll, "get_lib should never return None")
        self.assertIsInstance(
            stp_cdll, ctypes.CDLL, "get_lib should return a CDLL type"
        )
        self.assertTrue(
            callable(stp_cdll.vc_bvConcatExpr),
            "STP's CDLL object should have a method called 'vc_bvConcatExpr'",
        )

    # Check that wide values work.
    def test_wide0(self):
        """Solve a + b == 1337 with b == 42 on 1500-bit vectors."""
        s = self.s
        a = s.bitvec('a', 1500)
        b = s.bitvec('b', 1500)
        c = s.bitvecval(1500, 1337)
        d = s.bitvecval(1500, 42)
        self.assertTrue(s.check(b.eq(d), a.add(b).eq(c)))
        self.assertEqual(s.model(), {'a': 1337 - 42, 'b': 42})

    def test_wide1(self):
        """Subtract two adjacent 142-bit constants built from decimal strings."""
        s = self.s
        a = s.bitvec('a', 142)
        b = s.bitvecvalD(142, str(0x2FFFFFFFFFFF5FFFFFFFFFFFFFFFFFF1))
        c = s.bitvecvalD(142, str(0x2FFFFFFFFFFF5FFFFFFFFFFFFFFFFFF2))
        self.assertTrue(s.check(c.sub(b).eq(a)))
        self.assertEqual(s.model(), {'a': 1})

    def test_wide2(self):
        """A 142-bit constant built from a decimal string round-trips via the model."""
        s = self.s
        a = s.bitvec('a', 142)
        v = 0x2FFFFFFFFFFFFFFF5FFFFFFFFFFFFFFFFFF1
        b = s.bitvecvalD(142, str(v))
        self.assertTrue(s.check(a.eq(b)))
        self.assertEqual(s.model(), {'a': v})

    def test_wide3(self):
        """A 142-bit constant built from a Python integer round-trips via the model."""
        s = self.s
        a = s.bitvec('a', 142)
        v = 0x2FFFFFFFFFFFFFFF5FFFFFFFFFFFFFFFFFF1
        # Using the integer function to create a bitvector.
        b = s.bitvecval(142, v)
        self.assertTrue(s.check(a.eq(b)))
        self.assertEqual(s.model(), {'a': v})


if __name__ == '__main__':
    # Run only the TestSTP cases, verbosely, and turn the outcome into the
    # process exit status: a falsy status on success, a truthy one otherwise.
    stp_cases = unittest.TestLoader().loadTestsFromTestCase(TestSTP)
    verbose_runner = unittest.TextTestRunner(verbosity=2)
    outcome = verbose_runner.run(stp_cases)
    sys.exit(not outcome.wasSuccessful())

# EOF
